# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-argument
"""tvm.contrib.msc.core.tools.base_tool"""

import os
import copy
import logging
from itertools import product
from typing import List, Iterable, Any, Tuple, Dict
import numpy as np

import tvm
from tvm.contrib.msc.core.ir import MSCGraph, WeightGraph, MSCJoint, WeightJoint, MSCTensor
from tvm.contrib.msc.core.utils.namespace import MSCFramework
from tvm.contrib.msc.core import utils as msc_utils
from tvm.contrib.msc.core import _ffi_api


class ToolType(object):
    """Enum all msc tool types"""

    BASE = "base"
    WEIGHT = "weight"
    PRUNER = "pruner"
    QUANTIZER = "quantizer"
    DISTILLER = "distiller"
    TRACKER = "tracker"
    # The tool types reported by all_types(); BASE and WEIGHT are not listed here.
    ALL = [PRUNER, QUANTIZER, DISTILLER, TRACKER]

    @classmethod
    def all_types(cls) -> List[str]:
        """Get the tool types listed in ``ALL``

        Returns
        -------
        types: list<str>
            A new list on every call, so callers may modify it without
            altering the shared ``ToolType.ALL``.
        """

        return list(cls.ALL)


class ToolScope:
    """Enum of the scope marks a tool can see on a tensor

    ``BaseTool.process_tensor`` receives one of these marks, or an empty/None
    scope when the tensor carries no scope.
    """

    TEACHER, STUDENT = "teacher", "student"


class ToolExecutor(object):
    """Executor for process the tensor

    Binds a method to a config of default keyword arguments.

    Parameters
    ----------
    name: str
        The name.
    method: callable
        The method to execute.
    config: dict
        The default keyword arguments for the method; keyword arguments passed
        explicitly to ``execute`` take precedence over these.
    """

    def __init__(self, name: str, method: callable, config: dict = None):
        self._name = name
        self._method = method
        self._config = config or {}

    def __str__(self):
        return "{}({})".format(self._name, self._config)

    def execute(self, *args, **kwargs) -> Any:
        """execute the method

        Parameters
        ----------
        args: list<Any>
            The arguments for run method.
        kwargs: dict<Any>
            The key word arguments for run method.

        Returns
        -------
        plan or tensor:
           The plan generated by method or processed tensor.
        """

        # Only fill in config entries the caller did not pass explicitly.
        kwargs.update({k: v for k, v in self._config.items() if k not in kwargs})
        return self._method(*args, **kwargs)

    def copy(self, name: str = None, method: callable = None, config: dict = None):
        """Copy a executor

        Parameters
        ----------
        name: str
            The name for new executor; defaults to this executor's name.
        method: callable
            The method for new executor; defaults to this executor's method.
        config: dict
            Config entries that override this executor's config. The given
            dict is not modified.

        Returns
        -------
        new_executor: ToolExecutor
           The copied executor
        """

        # Work on a copy so the caller's ``config`` dict is never mutated.
        new_config = dict(config or {})
        new_config.update({k: v for k, v in self._config.items() if k not in new_config})
        return ToolExecutor(name or self._name, method or self._method, new_config)

    @property
    def name(self):
        """The name of the executor"""
        return self._name

    @property
    def config(self):
        """The config dict of the executor (the internal dict, not a copy)"""
        return self._config


class ToolStrategy(object):
    """Strategy for process tensor

    A strategy holds one executor per stage. It dispatches to the executor of
    its current stage, or to the ``"default"`` executor when the current stage
    has none.

    Parameters
    ----------
    name: str
        The name.
    tensor_type: str
        The tensor type.
    stage: str
        The init stage
    meta: dict:
        The meta strategy config.
    """

    def __init__(self, name: str, tensor_type: str, stage: str = "default", meta: dict = None):
        self._name = name
        self._tensor_type = tensor_type
        self._stage = stage
        self._executors = {}
        self._meta = meta

    def __str__(self):
        return "{}({} @ {}) ".format(self._name, self._tensor_type, self._stage) + "; ".join(
            ["{}:{}".format(k, v) for k, v in self._executors.items()]
        )

    def inspect(self) -> dict:
        """Get inspect of strategy

        Returns
        -------
        inspect: dict
           Maps "<stage>(<tensor_type>)" to the string form of each executor.
        """

        return {"{}({})".format(s, self._tensor_type): str(e) for s, e in self._executors.items()}

    def __call__(self, *args, **kwargs) -> Any:
        return self.apply(*args, **kwargs)

    def apply(self, *args, **kwargs) -> Any:
        """Apply the strategy

        Parameters
        ----------
        args: list<Any>
            The arguments for run method.
        kwargs: dict<Any>
            The key word arguments for run method.

        Returns
        -------
        plan or tensor:
           The plan generated by method or processed tensor.
        """

        return self.get_executor().execute(*args, **kwargs)

    def change_stage(self, stage: str):
        """Change the stage of strategy"""

        self._stage = stage

    def add_executor(self, stage: str, executor: "ToolExecutor"):
        """Add a executor to strategy

        Parameters
        ----------
        stage: str
            The mark of the executor.
        executor: ToolExecutor
            The executor to process tensor.
        """

        self._executors[stage] = executor
        # Adopt the first registered stage when no stage was set yet.
        if not self._stage:
            self._stage = stage

    def get_executor(self) -> "ToolExecutor":
        """Get executor of current stage

        Returns
        -------
        executor: ToolExecutor
           The executor of the current stage, or the ``"default"`` executor
           when the current stage has none.

        Raises
        ------
        KeyError
            If neither the current stage nor ``"default"`` has an executor.
        """

        if self._stage in self._executors:
            return self._executors[self._stage]
        return self._executors["default"]

    def get_config(self) -> dict:
        """Get the config of current executor"""

        return self.get_executor().config

    def support_stage(self, stage: str) -> bool:
        """Check if the strategy support a stage

        Parameters
        ----------
        stage: str
            The mark of the executor

        Returns
        -------
        support: bool
           Whether the strategy has an executor for the stage or a default one.
        """

        return stage in self._executors or "default" in self._executors

    def copy(
        self,
        name: str = None,
        tensor_type: str = None,
        stage: str = None,
        configs: Dict[str, dict] = None,
    ):
        """Copy a strategy

        Parameters
        ----------
        name: str
            The name for new strategy
        tensor_type:
            The tensor type for new strategy
        stage: str
            The init stage for new strategy
        configs: dict<str,dict>
            Per-stage config overrides for the copied executors.

        Returns
        -------
        new_strategy: ToolStrategy
           The copied strategy; it shares this strategy's meta config.
        """

        configs = configs or {}
        # Carry the meta config over; previously the copy silently lost it.
        strategy = ToolStrategy(
            name or self._name,
            tensor_type or self._tensor_type,
            stage or self._stage,
            meta=self._meta,
        )
        for st_name, executor in self._executors.items():
            new_executor = executor.copy(config=configs.get(st_name, {}))
            strategy.add_executor(st_name, new_executor)
        return strategy

    @property
    def meta(self):
        """The meta strategy config this strategy was parsed from"""
        return self._meta


class BaseTool(object):
    """Basic tool of MSC

    Parameters
    ----------
    stage: str
        The stage of tool
    plan_file: str
        The plan file path. When it is empty/None or does not point to an
        existing file, the tool starts with an empty plan.
    strategys: list[dict]
        The strategys of the tool.
    cache_processed: bool
        Whether to cache processed tensor.
    options: dict
        The extra options for the tool
    debug_level: int
        The debug level.
    verbose_step: int
        The verbose interval step.
    logger: logging.Logger
        The logger
    """

    def __init__(
        self,
        stage: str,
        plan_file: str,
        strategys: List[dict],
        cache_processed: bool = True,
        options: dict = None,
        debug_level: int = 0,
        verbose_step: int = 50,
        logger: logging.Logger = None,
    ):
        self._stage = stage
        # os.path.isfile(None) can raise TypeError rather than returning False,
        # so check for a missing path first.
        if plan_file and os.path.isfile(plan_file):
            self._plan = msc_utils.load_dict(plan_file)
        else:
            self._plan = {}
        # Parse a copy, because _parse_strategys pops keys from each strategy dict.
        self._strategys = self._parse_strategys(msc_utils.copy_dict(strategys))
        self._cache_processed = cache_processed
        self._options = options or {}
        self._debug_level = debug_level
        self._verbose_step = verbose_step
        self._logger = logger or msc_utils.get_global_logger()
        title = "{}.SETUP({} @ {})".format(self.tool_type().upper(), self._stage, self.framework())
        self._logger.info(msc_utils.msg_block(title, self.setup(), width=0))
        if self._debug_level >= 3 and self._plan:
            title = "{}.PLAN".format(self.tool_type().upper())
            self._logger.debug(msc_utils.msg_block(title, self._plan))

    def setup(self) -> dict:
        """Setup the tool

        Initializes the runtime state (caches, mode flags, graphs and weights).
        It is called from ``__init__``.

        Returns
        -------
        info: dict
            The setup info.
        """

        self._tensor_cache = {}
        self._enabled, self._is_training = True, True
        self._graphs, self._weights = [], {}
        self._graph_id, self._forward_cnt = 0, 0
        self._processed_tensor = {}
        return {
            "style": self.tool_style(),
            "strategys": {k: v.inspect() for k, v in self._strategys.items()},
            "cache_processed": self._cache_processed,
            "options": self._options,
            "planed_num": len(self._plan),
            "verbose_step": self._verbose_step,
            "debug_level": self._debug_level,
        }

    def _parse_strategys(self, strategy_list: List[dict]) -> Dict[str, ToolStrategy]:
        """Parse the strategy to get valid strategy

        Each strategy dict may contain "method_cls", "method", "tensor_types",
        one of "op_types"/"op_names"/"tensor_names", and "stages". These keys
        are popped from the dict. The remaining entries become the executor
        config.

        Parameters
        ----------
        strategy_list: list<dict>
            The given strategys. The dicts are modified in place.

        Returns
        -------
        strategys: dict<str, ToolStrategy>
            The parsed strategy, keyed by mark.
        """

        strategys = {}
        assert isinstance(strategy_list, list) and all(
            isinstance(s, dict) for s in strategy_list
        ), "ToolStrategy should be given as list of dict"
        for strategy in strategy_list:
            meta_strategy = msc_utils.copy_dict(strategy)
            method_cls_name = strategy.pop("method_cls") if "method_cls" in strategy else "default"
            method_cls = msc_utils.get_registered_tool_method(
                self.framework(), self.tool_type(), method_cls_name
            )
            method_name = strategy.pop("method") if "method" in strategy else "default"
            # Look up the method in this order: the framework-specific class,
            # then the MSC class, then the registered functions.
            method = None
            if hasattr(method_cls, method_name):
                method = getattr(method_cls, method_name)
            if not method:
                default_cls = msc_utils.get_registered_tool_method(
                    MSCFramework.MSC, self.tool_type(), method_cls_name
                )
                if hasattr(default_cls, method_name):
                    method = getattr(default_cls, method_name)
            if not method:
                method = msc_utils.get_registered_func(method_name)
            assert method, "Can not find method with " + str(method_name)
            tensor_types = strategy.pop("tensor_types") if "tensor_types" in strategy else ["all"]
            # Build the (mark, tensor_type) pairs that this strategy is registered under.
            if "op_types" in strategy:
                op_types = strategy.pop("op_types")
                marks = [("{}.{}".format(s, t), t) for s, t in product(op_types, tensor_types)]
            elif "op_names" in strategy:
                op_names = strategy.pop("op_names")
                marks = [("{}.{}".format(s, t), t) for s, t in product(op_names, tensor_types)]
            elif "tensor_names" in strategy:
                tensor_names = strategy.pop("tensor_names")
                marks = [(n, "all") for n in tensor_names]
            else:
                marks = [("default", "all")]
            stages = strategy.pop("stages") if "stages" in strategy else ["default"]
            for mark, t_type in marks:
                if mark not in strategys:
                    strategys[mark] = ToolStrategy(mark, t_type, self._stage, meta_strategy)
                for stage in stages:
                    # Each executor gets its own copy of the remaining config.
                    strategys[mark].add_executor(
                        stage, ToolExecutor(method_name, method, copy.deepcopy(strategy))
                    )
        return strategys

    def reset(
        self,
        graphs: List[MSCGraph],
        weights: List[Dict[str, tvm.nd.array]],
        cache_dir: msc_utils.MSCDirectory = None,
    ) -> Tuple[List[MSCGraph], List[Dict[str, tvm.nd.array]]]:
        """Reset the tool with graphs and weights

        Parameters
        ----------
        graphs: list<MSCgraph>
            The msc graphs.
        weights: list<dict<str, tvm.nd.array>>
            The weights
        cache_dir: MSCDirectory
            cache path for save/load info

        Returns
        -------
        graphs: list<MSCgraph>
            The msc graphs.
        weights: list<dict<str, tvm.nd.array>>
            The weights
        """

        self._forward_cnt = 0
        self._tensor_cache = {}
        # cache_info.json holds one entry per tool type.
        if cache_dir and os.path.isfile(cache_dir.relpath("cache_info.json")):
            cache_info = msc_utils.load_dict(cache_dir.relpath("cache_info.json"))
        else:
            cache_info = {}
        if self.tool_type() in cache_info:
            self.load_cache(cache_dir, cache_info[self.tool_type()])
        self._graphs, weights = self._reset(graphs, weights)
        # Flatten the per-graph weight dicts into one name -> weight map.
        self._weights = {}
        for sub_weights in weights:
            self._weights.update(sub_weights)
        self._logger.debug(
            "%s reset %d graphs, %d weights",
            self.tool_type(),
            len(self._graphs),
            len(self._weights),
        )
        return self._graphs, weights

    def _reset(
        self, graphs: List[MSCGraph], weights: List[Dict[str, tvm.nd.array]]
    ) -> Tuple[List[MSCGraph], List[Dict[str, tvm.nd.array]]]:
        """Reset the tool; the base implementation returns the inputs unchanged

        Parameters
        ----------
        graphs: list<MSCgraph>
            The msc graphs.
        weights: list<dict<str, tvm.nd.array>>
            The weights

        Returns
        -------
        graphs: list<MSCgraph>
            The msc graphs.
        weights: list<dict<str, tvm.nd.array>>
            The weights
        """

        return graphs, weights

    def change_stage(self, stage: str):
        """Change the stage of tool and strategy"""

        self._stage = stage
        for strategy in self._strategys.values():
            strategy.change_stage(stage)

    def change_logger(self, logger: logging.Logger):
        """Change the logger of tool"""

        self._logger = logger

    def destory(self):
        """Destroy the tool by dropping its graphs and weights"""

        self._graphs, self._weights = [], {}

    def load_cache(self, cache_dir: msc_utils.MSCDirectory, cache_info: dict):
        """Load the tool info from cache; the base implementation does nothing

        Parameters
        ----------
        cache_dir: MSCDirectory
            cache path for save/load info
        cache_info: dict
            The cache_info entry for this tool type.
        """

        return None

    def save_cache(self, cache_dir: msc_utils.MSCDirectory) -> dict:
        """Save the tool info to cache; the base implementation saves nothing

        Parameters
        ----------
        cache_dir: MSCDirectory
            cache path for save/load info

        Returns
        -------
        cache_info: dict
            The cache_info.
        """

        return {}

    def execute_before_build(self, *args, **kwargs):
        """Execute before model build

        Parameters
        ----------
        args: list<Any>
            The arguments for model build.
        kwargs: dict<Any>
            The key word arguments for model build. "graph_id"/"graph_name" are
            consumed here and not forwarded.
        """

        if self._enabled:
            self._graph_id = self._infer_graph_id(kwargs)
            self._processed_tensor = {}
            if self.on_debug(3, in_forward=False):
                self._logger.debug("%sStart Build", self.msg_mark(in_forward=False))
            self._execute_before_build(*args, **kwargs)

    def _execute_before_build(self, *args, **kwargs):
        """Execute before model build

        Parameters
        ----------
        args: list<Any>
            The arguments for model build.
        kwargs: dict<Any>
            The key word arguments for model build.
        """

        return None

    def execute_after_build(self, output: Any) -> Any:
        """Execute after model build

        Parameters
        ----------
        output: Any
            The output reference of the model.

        Returns
        -------
        output: Any
           The modified output reference.
        """

        if self._enabled:
            output = self._execute_after_build(output)
            if self.on_debug(3, in_forward=False):
                self._logger.debug("%sEnd Build", self.msg_mark(in_forward=False))
        return output

    def _execute_after_build(self, output: Any) -> Any:
        """Execute after model build

        Parameters
        ----------
        output: Any
            The output reference of the model.

        Returns
        -------
        output: Any
           The modified output reference.
        """

        return output

    def execute_before_forward(self, *args, **kwargs):
        """Execute before model forward

        Parameters
        ----------
        args: list<Any>
            The arguments for model forward.
        kwargs: dict<Any>
            The key word arguments for model forward. "graph_id"/"graph_name"
            are consumed here and not forwarded.
        """

        if self._enabled:
            self._graph_id = self._infer_graph_id(kwargs)
            self._processed_tensor = {}
            if self.on_debug(3):
                self._logger.debug("%sStart Forward", self.msg_mark())
            self._execute_before_forward(*args, **kwargs)

    def _execute_before_forward(self, *args, **kwargs):
        """Execute before model forward

        Parameters
        ----------
        args: list<Any>
            The arguments for model forward.
        kwargs: dict<Any>
            The key word arguments for model forward.
        """

        return None

    def execute_after_forward(self, output: Any) -> Any:
        """Execute after model forward

        Parameters
        ----------
        output: Any
            The output reference of the model.

        Returns
        -------
        output: Any
           The modified output reference.
        """

        if self._enabled:
            output = self._execute_after_forward(output)
            if self.on_debug(3):
                self._logger.debug(
                    "%sEnd Forward, process %d tensors",
                    self.msg_mark(),
                    len(self._processed_tensor),
                )
            self._forward_cnt += 1
        return output

    def _execute_after_forward(self, output: Any) -> Any:
        """Execute after model forward

        Parameters
        ----------
        output: Any
            The output reference of the model.

        Returns
        -------
        output: Any
           The modified output reference.
        """

        return output

    def process_tensor(self, tensor: Any, name: str, consumer: str, scope: str) -> Any:
        """Process tensor

        Parameters
        ----------
        tensor: Any
            Tensor in framework
        name: str
            The name of the tensor.
        consumer: str
            The name of the consumer.
        scope: str
            The scope mark teacher| student| null

        Returns
        -------
        tensor: Any
            The processed tensor, or the input tensor when it is not processed.
        """

        if not self._enabled:
            return tensor
        if not self._support_scope(scope):
            return tensor
        strategys = self._get_tensor_strategys(name, consumer)
        # The strategy mark (executor names plus scope) keys the processed-tensor cache.
        t_mark = ".".join([s.get_executor().name for s in strategys])
        if scope:
            t_mark += "." + scope
        cached_tensor = self._get_processed(name, consumer, t_mark)
        if cached_tensor is not None:
            self.debug_tensor(cached_tensor, name, consumer, "cached({})".format(t_mark))
            return cached_tensor
        # The "should process" decision is computed once per tensor id and cached.
        process = self._get_tensor_cache(name, consumer, "process")
        if process is None:
            process = self._check_tensor(name, consumer)
            self._save_tensor_cache(name, consumer, "process", process)
            if process and self.on_debug(3):
                self._logger.debug("%sprocess tensor %s-%s", self.msg_mark(), name, consumer)
        if not process:
            return tensor
        tensor = self._process_tensor(tensor, name, consumer, scope, strategys)
        self._save_processed(name, consumer, tensor, t_mark)
        self.debug_tensor(tensor, name, consumer, "processed({})".format(t_mark))
        return tensor

    def _support_scope(self, scope: str) -> bool:
        """Check if the scope is supported

        Parameters
        ----------
        scope: str
            The scope mark, should be null or ToolScope

        Returns
        -------
        valid: bool
            Whether to process the tensor; the base tool skips the teacher scope.
        """

        if not scope:
            return True
        return scope != ToolScope.TEACHER

    def _get_processed(self, name: str, consumer: str, strategy_mark: str) -> Any:
        """Get cached processed tensor

        The key uses only the tensor name and strategy mark, so consumers that
        share the same strategies reuse one processed tensor.

        Parameters
        ----------
        name: str
            The name of the tensor.
        consumer: str
            The name of the consumer.
        strategy_mark: str
            The strategy mark.

        Returns
        -------
        processed_tensor
            The cached processed tensor, or None.
        """

        if self._cache_processed:
            return self._processed_tensor.get(name + "." + strategy_mark)
        return None

    def _save_processed(self, name: str, consumer: str, tensor: Any, strategy_mark: str):
        """Save cached processed tensor

        Parameters
        ----------
        name: str
            The name of the tensor.
        consumer: str
            The name of the consumer.
        tensor: Any
            The processed tensor
        strategy_mark: str
            The strategy mark.
        """

        if self._cache_processed:
            self._processed_tensor[name + "." + strategy_mark] = tensor
        else:
            # Record the tensor id without keeping the tensor, so the processed
            # count in execute_after_forward stays accurate.
            self._processed_tensor[self.to_tensor_id(name, consumer)] = None

    def _check_tensor(self, name: str, consumer: str) -> bool:
        """Check if the tensor should be processed

        Parameters
        ----------
        name: str
            The name of the tensor.
        consumer: str
            The name of the consumer.

        Returns
        -------
        valid: bool
            Whether to process the tensor.
        """

        strategys = self._get_tensor_strategys(name, consumer)
        return len(strategys) > 0

    def _process_tensor(
        self, tensor: Any, name: str, consumer: str, scope: str, strategys: List[ToolStrategy]
    ) -> Any:
        """Process tensor; the base implementation returns it unchanged

        Parameters
        ----------
        tensor: Any
            Tensor in framework
        name: str
            The name of the tensor.
        consumer: str
            The name of the consumer.
        scope: str
            The scope mark teacher| student| null.
        strategys: list<ToolStrategy>
            The strategys for the tensor.

        Returns
        -------
        tensor: Any
            The processed tensor.
        """

        return tensor

    def create_tasks(self, **kwargs) -> List[dict]:
        """Create tasks for gym

        Parameters
        ----------
        kwargs: dict
           The kwargs for create tasks.

        Returns
        -------
        tasks: list<dict>
            The tasks; the base implementation returns an empty list.
        """

        return []

    def config_generate(self, generate_config: Dict[str, Any]) -> Dict[str, Any]:
        """Update the generate configs

        Parameters
        ----------
        generate_config: dict<str, Any>
            The generate_config.

        Returns
        -------
        generate_config: dict<str, Any>
            The updated generate_config.
        """

        return generate_config

    def visualize(self, visual_dir: msc_utils.MSCDirectory):
        """Visualize MSCGraphs; the base implementation does nothing

        Parameters
        ----------
        visual_dir: MSCDirectory
            Visualize path for saving graph
        """

        return None

    def set_plan(self, plan: dict):
        """Set the plan

        Merges into the existing plan when there is one; otherwise the given
        dict is used directly.

        Parameters
        ----------
        plan: dict
            The new plan.
        """

        if self._plan:
            self._plan = msc_utils.update_dict(self._plan, plan)
        else:
            self._plan = plan

    def finalize(self) -> dict:
        """Get the plan"""

        return self._plan

    def enable(self):
        """Enable the tool"""

        self._enabled = True

    def disable(self):
        """Disable the tool"""

        self._enabled = False

    def train(self):
        """Set the tool to train mode"""

        self._is_training = True

    def eval(self):
        """Set the tool to eval mode"""

        self._is_training = False

    def to_tensor_id(self, name: str, consumer: str) -> str:
        """Concat name to unique id

        Parameters
        ----------
        name: str
            The name of tensor.
        consumer: str
            The name of consumer.

        Returns
        -------
        tensor_id: str
           The unique name of edge.
        """

        return "{}-c-{}".format(name, consumer)

    def from_tensor_id(self, tensor_id: str) -> List[str]:
        """Split name from unique id

        Parameters
        ----------
        tensor_id: str
           The unique name of edge.

        Returns
        -------
        names: list<str>
            The name of tensor and the name of consumer.
        """

        return tensor_id.split("-c-")

    def is_weight(self, name: str) -> bool:
        """Check if the tensor is weight

        Parameters
        ----------
        name: str
           The name of tensor.

        Returns
        -------
        is_weight: bool
            Whether the name is weight.
        """

        return name in self._weights

    def on_debug(self, debug_level: int = 1, in_forward: bool = True) -> bool:
        """Check if should log

        Parameters
        ----------
        debug_level: int
           The given debug_level.
        in_forward: bool
            Whether to check forward_cnt.

        Returns
        -------
        on_debug: bool
            Whether to log debug info.
        """

        # A verbose_step of 0/None would raise ZeroDivisionError on the modulo;
        # in that case, log on every forward.
        if (
            in_forward
            and self._verbose_step
            and self._forward_cnt % self._verbose_step != 0
        ):
            return False
        return self._debug_level >= debug_level

    def msg_mark(self, in_forward: bool = True) -> str:
        """Get the debug title

        Parameters
        ----------
        in_forward: bool
            Whether to add forward mark.

        Returns
        -------
        msg_mark: str
            Get the debug title.
        """

        title = "{}.G[{}]".format(self.tool_type().upper(), self._graph_id)
        if in_forward:
            title += ".F[{}]".format(self._forward_cnt)
        title += "({}) ".format(self._stage)
        return title

    def debug_tensor(
        self, tensor: Any, name: str, consumer: str, t_mark: str, debug_level: int = 3
    ) -> None:
        """Log the debug tensor info

        Parameters
        ----------
        tensor: array_like
            The tensor
        name: str
            The name of tensor.
        consumer: str
            The name of consumer.
        t_mark: str
            The mark of tensor.
        debug_level: int
           The given debug_level.
        """

        if self.on_debug(debug_level):
            self._logger.debug(
                "%s%s %s-%s: %s",
                self.msg_mark(),
                t_mark,
                name,
                consumer,
                msc_utils.inspect_array(tensor),
            )

    def _infer_graph_id(self, kwargs: dict) -> int:
        """Infer graph id from kwargs

        Pops "graph_id" or "graph_name" from ``kwargs``.

        Parameters
        ----------
        kwargs: dict
           The kwargs for execute.

        Returns
        -------
        graph_id: int
            The given graph_id, the index of the graph named graph_name, or 0.
        """

        if "graph_id" in kwargs:
            return kwargs.pop("graph_id")
        if "graph_name" in kwargs:
            name = kwargs.pop("graph_name")
            for idx, g in enumerate(self._graphs):
                if g.name == name:
                    return idx
        return 0

    def get_nodes(self) -> Iterable[MSCJoint]:
        """Get all the nodes in the graphs.

        Returns
        -------
        nodes: generator<MSCJoint>
            The generator of nodes.
        """

        for g in self._graphs:
            for n in g.get_nodes():
                yield n

    def find_node(self, name: str) -> MSCJoint:
        """Find node by name.

        Parameters
        ----------
        name: string
            The name of the node.

        Returns
        -------
        node: MSCJoint
            The found node.
        """

        for g in self._graphs:
            if g.has_node(name):
                return g.find_node(name)
        raise Exception("Can not find node {} from {} graphs".format(name, len(self._graphs)))

    def find_tensor(self, name: str) -> MSCTensor:
        """Find tensor by name.

        Parameters
        ----------
        name: string
            The name of the tensor.

        Returns
        -------
        node: MSCTensor
            The found tensor.
        """

        for g in self._graphs:
            if g.has_tensor(name):
                return g.find_tensor(name)
        raise Exception("Can not find tensor {} from {} graphs".format(name, len(self._graphs)))

    def find_producer(self, name: str) -> MSCJoint:
        """Find producer by tensor_name .

        Parameters
        ----------
        name: string
            The name of the tensor.

        Returns
        -------
        node: MSCJoint
            The found producer.
        """

        for g in self._graphs:
            if g.has_tensor(name):
                return g.find_producer(name)
        raise Exception(
            "Can not find producer of {} from {} graphs".format(name, len(self._graphs))
        )

    def find_consumers(self, name: str) -> List[MSCJoint]:
        """Find consumers by tensor_name.

        Parameters
        ----------
        name: string
            The name of the tensor.

        Returns
        -------
        node: list<MSCJoint>
            The found consumers.
        """

        for g in self._graphs:
            if g.has_tensor(name):
                return g.find_consumers(name)
        raise Exception(
            "Can not find consumers of {} from {} graphs".format(name, len(self._graphs))
        )

    def get_data(self, name: str) -> np.ndarray:
        """Get the data by name

        Parameters
        ----------
        name: str
            The tensor name

        Returns
        -------
        data: np.ndarray
            The data.
        """

        if name in self._weights:
            return msc_utils.cast_array(self._weights[name])
        raise Exception("Can not find data {} from {} weights".format(name, len(self._weights)))

    def _save_tensor_cache(self, name: str, consumer: str, key: str, value: Any):
        """Save the data to tensor cache

        Parameters
        ----------
        name: str
            The tensor name.
        consumer: str
            The name of the consumer.
        key: str
            The data key.
        value: any
            The value to cache.
        """

        tensor_id = self.to_tensor_id(name, consumer)
        if tensor_id not in self._tensor_cache:
            self._tensor_cache[tensor_id] = {}
        self._tensor_cache[tensor_id][key] = value

    def _get_tensor_cache(self, name: str, consumer: str, key: str) -> Any:
        """Get the cached tensor data

        Parameters
        ----------
        name: str
            The tensor name.
        consumer: str
            The name of the consumer.
        key: str
            The data key.

        Returns
        -------
        value: any
            The cached value, or None if absent.
        """

        tensor_id = self.to_tensor_id(name, consumer)
        if tensor_id not in self._tensor_cache:
            return None
        return self._tensor_cache[tensor_id].get(key)

    def _get_tensor_strategys(self, name: str, consumer: str) -> List[ToolStrategy]:
        """Get the strategys by name and consumer

        Lookup order:
        1. A strategy registered for the exact tensor id.
        2. Otherwise, every matching producer/consumer name and optype strategy.
        3. Otherwise, the "default" strategy.
        Only strategies that support the current stage are considered. The
        result is cached per tensor id and stage.

        Parameters
        ----------
        name: str
            The tensor name.
        consumer: str
            The name of the consumer.

        Returns
        -------
        strategys: list<ToolStrategy>
            The strategys for the tensor.
        """

        tensor_id = self.to_tensor_id(name, consumer)
        mark = "strategy.{}".format(self._stage)
        if mark not in self._tensor_cache.get(tensor_id, {}):
            # Keep `consumer` as the string name. The resolved nodes get their
            # own variables, so the cache key written below matches the key
            # probed above. Rebinding `consumer` to the node made every lookup miss.
            if self.is_weight(name):
                consumer_node = self.find_node(consumer)
                name_refs = [
                    consumer_node.name + ".weight",
                    consumer_node.optype + ".weight",
                    consumer_node.optype + ".all",
                ]
            elif consumer == "exit":
                producer_node = self.find_producer(name)
                name_refs = [
                    producer_node.name + ".output",
                    producer_node.optype + ".output",
                    producer_node.optype + ".all",
                ]
            else:
                consumer_node = self.find_node(consumer)
                producer_node = self.find_producer(name)
                name_refs = [
                    producer_node.name + ".output",
                    producer_node.optype + ".output",
                    producer_node.optype + ".all",
                    consumer_node.name + ".input",
                    consumer_node.optype + ".input",
                    consumer_node.optype + ".all",
                ]
            strategys = []
            tensor_strategy = self._strategys.get(tensor_id)
            if tensor_strategy and tensor_strategy.support_stage(self._stage):
                strategys.append(tensor_strategy)
            if not strategys:
                for n in name_refs:
                    if n in self._strategys and self._strategys[n].support_stage(self._stage):
                        strategys.append(self._strategys[n])
            d_strategy = self._strategys.get("default")
            if not strategys and d_strategy and d_strategy.support_stage(self._stage):
                strategys.append(d_strategy)
            self._save_tensor_cache(name, consumer, mark, strategys)
        return self._get_tensor_cache(name, consumer, mark)

    def _get_tensor_strategy(self, name: str, consumer: str) -> ToolStrategy:
        """Get the unique strategy by name and consumer

        Parameters
        ----------
        name: str
            The tensor name.
        consumer: str
            The name of the consumer.

        Returns
        -------
        strategy: ToolStrategy
            The unique strategy for the tensor, or None if it has none.
        """

        strategys = self._get_tensor_strategys(name, consumer)
        if not strategys:
            return None
        assert len(strategys) == 1, "{} should only has 1 strategy, get {}".format(
            self._stage, strategys
        )
        return strategys[0]

    def get_graph(self):
        """Get the graph selected by the current graph id"""

        return self._graphs[self._graph_id]

    @classmethod
    def tool_type(cls):
        """The type of the tool"""

        return ToolType.BASE

    @classmethod
    def framework(cls):
        """The framework the tool works on"""

        return MSCFramework.MSC

    @classmethod
    def tool_style(cls):
        """The style name of the tool"""

        return "base"


class WeightTool(BaseTool):
    """Basic tool with weight graphs"""

    def setup(self) -> dict:
        """Prepare an empty weight-graph container, then run the base setup.

        Returns
        -------
        info: dict
            The setup info returned by the parent setup.
        """

        # Start with no weight graphs; _reset builds them when this list is empty
        # and load_cache fills it from cached json files.
        self._weight_graphs = []
        info = super().setup()
        return info

    def _reset(
        self, graphs: List[MSCGraph], weights: List[Dict[str, tvm.nd.array]]
    ) -> Tuple[List[MSCGraph], List[Dict[str, tvm.nd.array]]]:
        """Reset the tool and make sure a weight graph exists for every msc graph

        Parameters
        ----------
        graphs: list<MSCGraph>
            The msc graphs.
        weights: list<dict<str, tvm.nd.array>>
            The weights

        Returns
        -------
        graphs: list<MSCGraph>
            The msc graphs.
        weights: list<dict<str, tvm.nd.array>>
            The weights

        Raises
        ------
        AssertionError
            If no main weight types are given, or if previously set weight graphs
            do not match the number of graphs.
        """

        graphs, weights = super()._reset(graphs, weights)
        self._main_wtypes, self._relation_wtypes = self._get_wtypes()
        assert self._main_wtypes, "main_wtypes should be given to build weight graphs"
        if self._weight_graphs:
            # Weight graphs already exist (e.g. loaded by load_cache): only validate them.
            assert len(graphs) == len(
                self._weight_graphs
            ), "Graphs {} mismatch with weight graphs {}".format(
                len(graphs), len(self._weight_graphs)
            )
        else:
            # Build one weight graph per msc graph from the configured weight types.
            self._weight_graphs = [
                _ffi_api.WeightGraph(graph, self._main_wtypes, self._relation_wtypes)
                for graph in graphs
            ]
            self._logger.debug(
                "%s reset %d weight graphs", self.tool_type(), len(self._weight_graphs)
            )
        if self.on_debug(2, in_forward=False):
            for idx, graph in enumerate(self._weight_graphs):
                self._logger.debug(
                    msc_utils.msg_block("WEIGHT_GRAPH[{}].INFO".format(idx), graph.inspect())
                )
        return graphs, weights

    def _get_wtypes(self) -> Tuple[Dict[str, List[str]], Dict[str, str]]:
        """Return the weight types used to build the weight graphs

        Subclasses must override this; the base implementation always raises.

        Returns
        -------
        main_wtypes: dict<str,list<str>>
            The main weight types.
        relation_wtypes: dict<str, str>
            The relation weight types

        Raises
        ------
        NotImplementedError
            Always, in WeightTool itself.
        """

        message = "_get_wtypes is not implemented in WeightTool"
        raise NotImplementedError(message)

    def load_cache(self, cache_dir: msc_utils.MSCDirectory, cache_info: dict):
        """Load weight graphs from cache

        Parameters
        ----------
        cache_dir: MSCDirectory
            cache path for save/load info
        cache_info: dict
            The cache_info, must contain "weight_graphs": the json file names
            of the weight graphs, relative to cache_dir.

        Raises
        ------
        AssertionError
            If "weight_graphs" is missing from cache_info.
        """

        assert (
            "weight_graphs" in cache_info
        ), "weight_graphs should be given in cache_info, get " + str(cache_info)
        # Each entry is a file name relative to cache_dir (as written by save_cache).
        self._weight_graphs = [
            WeightGraph.from_json(cache_dir.relpath(f)) for f in cache_info["weight_graphs"]
        ]
        self._logger.debug(
            "%s load %d weight graphs from %s",
            self.tool_type(),
            len(self._weight_graphs),
            cache_dir,
        )

    def save_cache(self, cache_dir: msc_utils.MSCDirectory) -> dict:
        """Save weight graphs to cache

        Each weight graph is written as json to "<graph name>_graph.json".

        Parameters
        ----------
        cache_dir: MSCDirectory
            cache path for save/load info

        Returns
        -------
        cache_info: dict
            The cache_info, with "weight_graphs" listing the written file names.
        """

        cache_info = {"weight_graphs": [g.name + "_graph.json" for g in self._weight_graphs]}
        # NOTE(review): f_path is relative; presumably entering `cache_dir` makes the
        # files land inside that directory — confirm against MSCDirectory.__enter__.
        with cache_dir:
            for graph, f_path in zip(self._weight_graphs, cache_info["weight_graphs"]):
                with open(f_path, "w") as f_graph:
                    f_graph.write(graph.to_json())
        return cache_info

    def visualize(self, visual_dir: msc_utils.MSCDirectory):
        """Visualize the weight graphs

        Each weight graph is written to "<graph name>.prototxt" under visual_dir.

        Parameters
        ----------
        visual_dir: MSCDirectory
            Visualize path for saving graph
        """

        for w_graph in self._weight_graphs:
            w_graph.visualize(visual_dir.relpath(w_graph.name + ".prototxt"))

    def get_w_nodes(self) -> Iterable[WeightJoint]:
        """Iterate over every weight node of every weight graph.

        Returns
        -------
        nodes: generator<WeightJoint>
            Weight nodes, graph by graph, in the order given by each graph's get_nodes().
        """

        for w_graph in self._weight_graphs:
            yield from w_graph.get_nodes()

    def has_w_node(self, name: str) -> bool:
        """Tell whether any weight graph contains a node with the given name.

        Parameters
        ----------
        name: string
            The name of the node.

        Returns
        -------
        has_node: bool
            True as soon as one weight graph reports the node, otherwise False.
        """

        return any(w_graph.has_node(name) for w_graph in self._weight_graphs)

    def find_w_node(self, name: str) -> WeightJoint:
        """Look up a weight node by name in the first weight graph that has it.

        Parameters
        ----------
        name: string
            The name of the node.

        Returns
        -------
        node: WeightJoint
            The found node.

        Raises
        ------
        Exception
            If no weight graph contains the node.
        """

        owner = next((w_graph for w_graph in self._weight_graphs if w_graph.has_node(name)), None)
        if owner is None:
            raise Exception("Can not find node {} from graphs".format(name))
        return owner.find_node(name)

    def _get_io_axes(self, w_node: WeightJoint) -> Tuple[int, int]:
        """Resolve the input and output axes of a weight node.

        The resolution order is as follows:
        1. 1-D weights use axis 0 for both.
        2. Otherwise, explicit "in_axis"/"out_axis" attrs are used when both exist.
        3. Otherwise, the "I"/"O" layout axes are used when both exist.
        4. Otherwise, the "C" layout axis is used for both when it exists.

        Parameters
        ----------
        w_node: WeightJoint
            The weight node.

        Returns
        -------
        axes: (int, int)
            The input output axis.

        Raises
        ------
        Exception
            If none of the rules above applies.
        """

        weight = w_node.weight
        if weight.ndim == 1:
            return 0, 0
        if w_node.has_attr("in_axis") and w_node.has_attr("out_axis"):
            return int(w_node.get_attr("in_axis")), int(w_node.get_attr("out_axis"))
        in_axis = weight.layout_of("I")
        out_axis = weight.layout_of("O")
        if in_axis >= 0 and out_axis >= 0:
            return in_axis, out_axis
        channel_axis = weight.layout_of("C")
        if channel_axis >= 0:
            return channel_axis, channel_axis
        raise Exception("Can not infer in_axis/out_axis from " + str(w_node))

    @classmethod
    def tool_type(cls):
        """Return the tool type; weight-graph based tools report ToolType.WEIGHT."""
        weight_type = ToolType.WEIGHT
        return weight_type
